{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !wget https://f000.backblazeb2.com/file/malay-dataset/qa/natural/translated-train.json\n",
    "# !wget https://f000.backblazeb2.com/file/malay-dataset/qa/natural/translated-validation.json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Translated question/answer dumps; each one is opened and parsed with json.load further down.\n",
    "files = ['translated-train.json', 'translated-validation.json']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "def cleaning(string):\n",
    "    \"\"\"Turn newlines and tabs into spaces, squeeze runs of spaces, and trim both ends.\"\"\"\n",
    "    flattened = string.translate({ord('\\n'): ' ', ord('\\t'): ' '})\n",
    "    squeezed = re.sub(r'[ ]+', ' ', flattened)\n",
    "    return squeezed.strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "translated-train.json\n",
      "translated-validation.json\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "import tensorflow as tf\n",
    "\n",
    "# Split every row into a (question, answer) pair. Rows containing '<>' are\n",
    "# split on it; all other rows are split on '? ' and the '?' is put back on\n",
    "# the question.\n",
    "questions, answers = [], []\n",
    "skipped_rows = 0  # rows that could not be split into exactly two parts\n",
    "for file in files:\n",
    "    print(file)\n",
    "    # Decode explicitly as UTF-8 instead of relying on the platform default.\n",
    "    with open(file, encoding='utf-8') as fopen:\n",
    "        data = json.load(fopen)\n",
    "\n",
    "    for row in data:\n",
    "        try:\n",
    "            if '<>' not in row:\n",
    "                q, a = row.split('? ')\n",
    "                q = f'{q}?'\n",
    "            else:\n",
    "                q, a = row.split('<>')\n",
    "        except (ValueError, AttributeError, TypeError):\n",
    "            # Best effort: a row that is not a string, or that does not split\n",
    "            # into exactly two parts, is dropped. Only these errors are caught\n",
    "            # so that interrupts and real failures still surface.\n",
    "            skipped_rows += 1\n",
    "            continue\n",
    "        questions.append(q.strip())\n",
    "        answers.append(a.strip())\n",
    "\n",
    "# Write one 'question<TAB>answer' line per pair. cleaning() removes tabs and\n",
    "# newlines from both fields, so the separators stay unambiguous.\n",
    "with tf.io.gfile.GFile('qa.tsv', \"w\") as outfile:\n",
    "    for question, answer in zip(questions, answers):\n",
    "        cq = cleaning(question)\n",
    "        ca = cleaning(answer)\n",
    "        # Keep only pairs whose cleaned question has more than 5 characters\n",
    "        # and whose cleaned answer has more than 2.\n",
    "        if len(cq) > 5 and len(ca) > 2:\n",
    "            outfile.write(\"%s\\t%s\\n\" % (cq, ca))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import tensorflow_datasets as tfds\n",
    "from t5.data import preprocessors as prep\n",
    "import functools\n",
    "import t5\n",
    "import gin\n",
    "\n",
    "# Load the gin bindings from the operative config file in the working directory.\n",
    "gin.parse_config_file('pretrained_models_base_operative_config.gin')\n",
    "# SentencePiece model file; handed to the task registration as `sentencepiece_model_path`.\n",
    "vocab = 'sp10m.cased.ms-en.model'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def question_dataset(split, shuffle_files = False):\n",
    "    \"\"\"Build a dataset of {'question', 'answer'} string pairs from a TSV file.\n",
    "\n",
    "    `split` is handed straight to `tf.data.TextLineDataset` as the file to\n",
    "    read, and `shuffle_files` is accepted but ignored.\n",
    "    \"\"\"\n",
    "    del shuffle_files\n",
    "\n",
    "    def split_columns(line):\n",
    "        # Two tab-separated string columns; quote characters are not special.\n",
    "        return tf.io.decode_csv(\n",
    "            line,\n",
    "            record_defaults = ['', ''],\n",
    "            field_delim = '\\t',\n",
    "            use_quote_delim = False,\n",
    "        )\n",
    "\n",
    "    def name_columns(*columns):\n",
    "        return dict(zip(['question', 'answer'], columns))\n",
    "\n",
    "    lines = tf.data.TextLineDataset([split])\n",
    "    columns = lines.map(\n",
    "        split_columns,\n",
    "        num_parallel_calls = tf.data.experimental.AUTOTUNE,\n",
    "    )\n",
    "    return columns.map(name_columns)\n",
    "\n",
    "\n",
    "def question_preprocessor(ds):\n",
    "    \"\"\"Map {'question', 'answer'} examples to {'inputs', 'targets'} examples.\n",
    "\n",
    "    The question is prefixed with 'soalan: ' to form 'inputs'; the answer is\n",
    "    passed through unchanged as 'targets'.\n",
    "    \"\"\"\n",
    "\n",
    "    def as_text_pair(example):\n",
    "        prompt = tf.strings.join(['soalan: ', example['question']])\n",
    "        return {'inputs': prompt, 'targets': example['answer']}\n",
    "\n",
    "    return ds.map(\n",
    "        as_text_pair,\n",
    "        num_parallel_calls = tf.data.experimental.AUTOTUNE,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove any existing 'question_dataset' entry before adding it again\n",
    "# (presumably so this cell can be re-run in the same kernel -- confirm).\n",
    "t5.data.TaskRegistry.remove('question_dataset')\n",
    "# Register the task: question_dataset supplies the raw examples,\n",
    "# question_preprocessor turns them into inputs/targets, and `vocab` is the\n",
    "# SentencePiece model path.\n",
    "t5.data.TaskRegistry.add(\n",
    "    'question_dataset',\n",
    "    dataset_fn = question_dataset,\n",
    "    splits = ['train'],\n",
    "    text_preprocessor = [question_preprocessor],\n",
    "    sentencepiece_model_path = vocab,\n",
    "    metric_fns = [t5.evaluation.metrics.accuracy],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Look up the registered task and build its dataset.\n",
    "# NOTE(review): `split` is presumably forwarded to question_dataset, which uses\n",
    "# it as the TSV path rather than a split name -- confirm against the t5 version in use.\n",
    "nq_task = t5.data.TaskRegistry.get(\"question_dataset\")\n",
    "ds = nq_task.get_dataset(split='qa.tsv', sequence_length={\"inputs\": 1024, \"targets\": 1024})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the dataset so that iterating it yields NumPy values instead of tensors.\n",
    "r = tfds.as_numpy(ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at one example. `iter()` is a no-op when `r` is already a generator and\n",
    "# is required when `tfds.as_numpy` returns a plain iterable instead of an iterator.\n",
    "next(iter(r))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
